/**
 * 
 */
package edu.berkeley.nlp.discPCFG;

import java.io.Serializable;

import edu.berkeley.nlp.PCFGLA.BinaryRule;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.HierarchicalBinaryRule;
import edu.berkeley.nlp.PCFGLA.HierarchicalGrammar;
import edu.berkeley.nlp.PCFGLA.HierarchicalLexicon;
import edu.berkeley.nlp.PCFGLA.HierarchicalUnaryRule;
import edu.berkeley.nlp.PCFGLA.Rule;
import edu.berkeley.nlp.PCFGLA.SimpleLexicon;
import edu.berkeley.nlp.PCFGLA.SpanPredictor;
import edu.berkeley.nlp.PCFGLA.UnaryRule;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.math.DoubleArrays;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.util.ArrayUtil;

/**
 * Similar to the cascading linearizer, but does not compute the grammars
 * explicitly; instead it uses hierarchical rules and merges back unused splits.
 *
 * @author petrov
 */
public class HierarchicalLinearizer extends DefaultLinearizer {

	private static final long serialVersionUID = 1L;

	/** Grammar whose hierarchical rule scores are read and overwritten in place. */
	HierarchicalGrammar grammar;
	/** Lexicon whose hierarchical tag/word scores are read and overwritten in place. */
	HierarchicalLexicon lexicon;

	/** Deepest level handled here; sizes the mapping tables and is passed to explicitlyComputeScores. */
	int finalLevel;
	/** lexiconMapping[level][s] = s / 2^(finalLevel-level); row 0 is left all zero. */
	int[][] lexiconMapping;
	/** unaryMapping[level][a][b] = row-major index of the level-coarsened pair (a, b). */
	int[][][] unaryMapping;
	/** binaryMapping[level][a][b][c] = row-major index of the level-coarsened triple (a, b, c). */
	int[][][][] binaryMapping;

	/** Creates an empty linearizer; no field is initialised and no mapping table is built. */
	public HierarchicalLinearizer(){}

	/**
	 * Builds a linearizer over the given grammar and lexicon and precomputes the
	 * substate mapping tables.
	 *
	 * @param grammar grammar to linearize; cast to {@link HierarchicalGrammar}
	 * @param lexicon lexicon to linearize; cast to {@link HierarchicalLexicon}
	 * @param sp span predictor stored in the inherited {@code spanPredictor} field
	 * @param fLevel value stored as {@code finalLevel}
	 * @throws ClassCastException if grammar or lexicon is not of the hierarchical type
	 */
	public HierarchicalLinearizer(Grammar grammar, SimpleLexicon lexicon, SpanPredictor sp,  int fLevel) {
		this.grammar = (HierarchicalGrammar) grammar;
		this.lexicon = (HierarchicalLexicon) lexicon;
		this.spanPredictor = sp;
		this.finalLevel = fLevel;
		this.nSubstates = (int) ArrayUtil.max(grammar.numSubStates);
		init();
		computeMappings();
	}

	/**
	 * Fills {@code lexiconMapping}, {@code unaryMapping} and {@code binaryMapping}.
	 * <p>
	 * For every level in 1..finalLevel a substate index {@code s} is coarsened to
	 * {@code s / 2^(finalLevel-level)}. Coarsened pairs and triples are then
	 * numbered in row-major order over a table of width {@code 2^level}. Level 0
	 * is never written, so its rows keep the default value 0.
	 *
	 * @throws ArrayIndexOutOfBoundsException if {@code finalLevel >= 1} and
	 *         {@code nSubstates} exceeds {@code 2^finalLevel}
	 */
	protected void computeMappings() {
		lexiconMapping = new int[finalLevel + 1][nSubstates];
		unaryMapping = new int[finalLevel + 1][nSubstates][nSubstates];
		binaryMapping = new int[finalLevel + 1][nSubstates][nSubstates][nSubstates];

		// A coarsened index must stay below 2^level; that holds for every level exactly
		// when nSubstates <= 2^finalLevel. Fail up front with a readable message.
		// NOTE(review): the shifts assume finalLevel < 31 - confirm with callers.
		if (finalLevel >= 1 && nSubstates > (1 << finalLevel)) {
			throw new ArrayIndexOutOfBoundsException("nSubstates (" + nSubstates
					+ ") exceeds 2^finalLevel (finalLevel=" + finalLevel + ")");
		}

		for (int level = 1; level <= finalLevel; level++) {
			int div = 1 << (finalLevel - level);
			int width = 1 << level;
			for (int i = 0; i < nSubstates; i++) {
				int coarseI = i / div;
				lexiconMapping[level][i] = coarseI;
				for (int j = 0; j < nSubstates; j++) {
					int pairIndex = coarseI * width + j / div;
					unaryMapping[level][i][j] = pairIndex;
					for (int k = 0; k < nSubstates; k++) {
						binaryMapping[level][i][j][k] = pairIndex * width + k / div;
					}
				}
			}
		}
	}

	/**
	 * Copies weights from the flat vector back into the last-level score arrays of
	 * every binary and unary rule, then asks the grammar to rebuild its derived data.
	 * <p>
	 * Each rule reads from its own {@code identifier} offset. A value flagged by
	 * {@link SloppyMath#isVeryDangerous} is skipped (the old score is kept) but its
	 * slot is still consumed. Unary rules whose child and parent state are equal
	 * are not touched.
	 *
	 * @param probs flat weight vector laid out as by {@link #getLinearizedGrammar(boolean)}
	 */
	public void delinearizeGrammar(double[] probs) {
		int nDangerous = 0;
		for (BinaryRule bRule : grammar.binaryRuleMap.keySet()) {
			HierarchicalBinaryRule hRule = (HierarchicalBinaryRule) bRule;
			int ind = hRule.identifier;
			double[][][] scores = hRule.getLastLevel();
			for (double[][] plane : scores) {
				for (double[] row : plane) {
					if (row == null) continue;
					for (int l = 0; l < row.length; l++) {
						double val = probs[ind++];
						if (SloppyMath.isVeryDangerous(val)) {
							nDangerous++;
						} else {
							row[l] = val;
						}
					}
				}
			}
		}
		if (nDangerous>0) System.out.println("Left "+nDangerous+" binary rule weights unchanged since the proposed weight was dangerous.");

		nDangerous = 0;
		for (UnaryRule uRule : grammar.unaryRuleMap.keySet()) {
			HierarchicalUnaryRule hRule = (HierarchicalUnaryRule) uRule;
			int ind = hRule.identifier;
			if (uRule.childState == uRule.parentState) continue;
			double[][] scores = hRule.getLastLevel();
			for (double[] row : scores) {
				if (row == null) continue;
				for (int k = 0; k < row.length; k++) {
					double val = probs[ind++];
					if (SloppyMath.isVeryDangerous(val)) {
						nDangerous++;
					} else {
						row[k] = val;
					}
				}
			}
		}
		if (nDangerous>0) System.out.println("Left "+nDangerous+" unary rule weights unchanged since the proposed weight was dangerous.");

		grammar.explicitlyComputeScores(finalLevel);
		// The closed unary tables are simply aliased to the plain unary tables here.
		grammar.closedSumRulesWithParent = grammar.closedViterbiRulesWithParent = grammar.unaryRulesWithParent;
		grammar.closedSumRulesWithChild = grammar.closedViterbiRulesWithChild = grammar.unaryRulesWithC;
		grammar.clearUnaryIntermediates();
		grammar.makeCRArrays();
	}

	/**
	 * Copies weights from the flat vector back into the last-level score array of
	 * every (tag, word) lexicon entry, then asks the lexicon to recompute its scores.
	 * <p>
	 * Entries are read starting at {@code linearIndex[tag][word]}. A value flagged by
	 * {@link SloppyMath#isVeryDangerous} is skipped but its slot is still consumed.
	 *
	 * @param logProbs flat weight vector indexed by the offsets stored in {@code linearIndex}
	 */
	public void delinearizeLexicon(double[] logProbs) {
		int nDangerous = 0;
		// tag stays a short because it is passed as such to lexicon.getLastLevel.
		for (short tag = 0; tag < lexicon.hierarchicalScores.length; tag++) {
			for (int word = 0; word < lexicon.hierarchicalScores[tag].length; word++) {
				int index = linearIndex[tag][word];
				double[] vals = lexicon.getLastLevel(tag, word);
				for (int substate = 0; substate < vals.length; substate++) {
					double val = logProbs[index++];
					if (SloppyMath.isVeryDangerous(val)) {
						nDangerous++;
					} else {
						vals[substate] = val;
					}
				}
			}
		}
		if (nDangerous>0) System.out.println("Left "+nDangerous+" lexicon weights unchanged since the proposed weight was dangerous.");
		lexicon.explicitlyComputeScores(finalLevel);
	}

	/**
	 * Returns the last-level scores of all binary and unary rules as one flat vector.
	 *
	 * @param update if true, first reassigns every rule's {@code identifier} offset
	 *        and recomputes {@code nGrammarWeights}
	 * @return a new array of length {@code nGrammarWeights}
	 */
	public double[] getLinearizedGrammar(boolean update) {
		if (update) {
			assignHierarchicalRuleOffsets();
		}
		double[] logProbs = new double[nGrammarWeights];

		for (BinaryRule bRule : grammar.binaryRuleMap.keySet()) {
			HierarchicalBinaryRule hRule = (HierarchicalBinaryRule) bRule;
			int ind = hRule.identifier;
			double[][][] scores = hRule.getLastLevel();
			for (double[][] plane : scores) {
				for (double[] row : plane) {
					if (row == null) continue;
					System.arraycopy(row, 0, logProbs, ind, row.length);
					ind += row.length;
				}
			}
		}

		for (UnaryRule uRule : grammar.unaryRuleMap.keySet()) {
			HierarchicalUnaryRule hRule = (HierarchicalUnaryRule) uRule;
			int ind = hRule.identifier;
			// Rules with identical child and parent state are not copied; their slots stay 0.
			if (uRule.childState == uRule.parentState) continue;
			double[][] scores = hRule.getLastLevel();
			for (double[] row : scores) {
				if (row == null) continue;
				System.arraycopy(row, 0, logProbs, ind, row.length);
				ind += row.length;
			}
		}
		return logProbs;
	}

	/**
	 * Gives every binary rule, then every unary rule, its start offset in the flat
	 * grammar vector and sets {@code nGrammarWeights} to the total number of slots.
	 * Every unary rule reserves slots here, including those with identical child
	 * and parent state.
	 */
	private void assignHierarchicalRuleOffsets() {
		nGrammarWeights = 0;
		for (BinaryRule bRule : grammar.binaryRuleMap.keySet()) {
			HierarchicalBinaryRule hRule = (HierarchicalBinaryRule) bRule;
			if (!grammar.isGrammarTag[bRule.parentState]) {
				System.out.println("Incorrect grammar tag");
			}
			bRule.identifier = nGrammarWeights;
			double[][][] scores = hRule.getLastLevel();
			for (double[][] plane : scores) {
				for (double[] row : plane) {
					if (row != null) {
						nGrammarWeights += row.length;
					}
				}
			}
		}
		for (UnaryRule uRule : grammar.unaryRuleMap.keySet()) {
			HierarchicalUnaryRule hRule = (HierarchicalUnaryRule) uRule;
			uRule.identifier = nGrammarWeights;
			double[][] scores = hRule.getLastLevel();
			for (double[] row : scores) {
				if (row != null) {
					nGrammarWeights += row.length;
				}
			}
		}
	}

	/**
	 * Returns the last-level scores of all (tag, word) lexicon entries as one flat vector.
	 *
	 * @param update if true, recomputes {@code nLexiconWeights} and rebuilds
	 *        {@code linearIndex}; the stored offsets are shifted by
	 *        {@code nGrammarWeights}, while the returned array itself starts at 0
	 * @return a new array of length {@code nLexiconWeights}
	 */
	public double[] getLinearizedLexicon(boolean update) {
		if (update) {
			nLexiconWeights = 0;
			for (short tag = 0; tag < lexicon.hierarchicalScores.length; tag++) {
				for (int word = 0; word < lexicon.hierarchicalScores[tag].length; word++) {
					nLexiconWeights += lexicon.getLastLevel(tag, word).length;
				}
			}
		}
		double[] logProbs = new double[nLexiconWeights];
		if (update) linearIndex = new int[lexicon.hierarchicalScores.length][];

		int index = 0;
		for (short tag = 0; tag < lexicon.hierarchicalScores.length; tag++) {
			int nWords = lexicon.hierarchicalScores[tag].length;
			if (update) linearIndex[tag] = new int[nWords];
			for (int word = 0; word < nWords; word++) {
				if (update) linearIndex[tag][word] = index + nGrammarWeights;
				double[] vals = lexicon.getLastLevel(tag, word);
				System.arraycopy(vals, 0, logProbs, index, vals.length);
				index += vals.length;
			}
		}
		if (index != logProbs.length)
			System.out.println("unequal length in lexicon");

		return logProbs;
	}

	/**
	 * Looks up the flat-vector offset of a (word, tag) lexicon entry.
	 *
	 * @param globalWordIndex word index that is translated through {@code lexicon.tagWordIndexer[tag]}
	 * @param tag tag index
	 * @return the stored {@code linearIndex} entry, or -1 if the tag's indexer does not contain the word
	 */
	public int getLinearIndex(int globalWordIndex, int tag) {
		int tagSpecificWordIndex = lexicon.tagWordIndexer[tag].indexOf(globalWordIndex);
		return (tagSpecificWordIndex == -1) ? -1 : linearIndex[tag][tagSpecificWordIndex];
	}

	/**
	 * @return the sum of the grammar, lexicon and span weight counts
	 */
	public int dimension() {
		return nGrammarWeights + nLexiconWeights + nSpanWeights;
	}

	/**
	 * Adds (gold) or subtracts (non-gold) per-substate weights to the count slots of
	 * a lexicon entry, once for the state's signature index (when it is not -1 and
	 * is known for this tag) and once for its word index (when known for this tag).
	 * All of {@code weights[0..nSubstates)} is reset to 0 before returning.
	 *
	 * @param counts flat count vector that is updated in place
	 * @param stateSet supplies {@code sigIndex} and {@code wordIndex}
	 * @param tag tag index
	 * @param weights per-substate weights; cleared on return
	 * @param isGold true to add, false to subtract
	 */
	public void increment(double[] counts, StateSet stateSet, int tag, double[] weights, boolean isGold) {
		int globalSigIndex = stateSet.sigIndex;
		if (globalSigIndex != -1) {
			int sigStart = getLinearIndex(globalSigIndex, tag);
			if (sigStart >= 0) {
				int sigLevel = lexicon.getFinalLevel(globalSigIndex, tag);
				// Weights are kept: the word entry below still needs them.
				addHierarchicalLexiconCounts(counts, sigStart, sigLevel, weights, isGold, false);
			}
		}

		int globalWordIndex = stateSet.wordIndex;
		int wordStart = getLinearIndex(globalWordIndex, tag);
		if (wordStart >= 0) {
			int wordLevel = lexicon.getFinalLevel(globalWordIndex, tag);
			addHierarchicalLexiconCounts(counts, wordStart, wordLevel, weights, isGold, true);
		} else {
			for (int i = 0; i < nSubstates; i++) {
				weights[i] = 0;
			}
		}
	}

	/**
	 * Adds or subtracts {@code weights[i]} at {@code counts[start + lexiconMapping[level][i]]}
	 * for every substate, optionally zeroing each weight right after it is used.
	 */
	private void addHierarchicalLexiconCounts(double[] counts, int start, int level,
			double[] weights, boolean isGold, boolean clearWeights) {
		for (int i = 0; i < nSubstates; i++) {
			int target = start + lexiconMapping[level][i];
			if (isGold) {
				counts[target] += weights[i];
			} else {
				counts[target] -= weights[i];
			}
			if (clearWeights) {
				weights[i] = 0;
			}
		}
	}

	/**
	 * Adds (gold) or subtracts (non-gold) the strictly positive entries of
	 * {@code weights} to the count slots of a unary rule and zeroes each entry used.
	 * <p>
	 * When the rule's parent state is 0, {@code weights} is read as one value per
	 * substate and mapped through {@code lexiconMapping}. Otherwise it is read as an
	 * {@code nSubstates x nSubstates} row-major table mapped through {@code unaryMapping}.
	 *
	 * @param counts flat count vector that is updated in place
	 * @param rule rule to credit; cast to {@link HierarchicalUnaryRule}
	 * @param weights weights to fold in; consumed entries are set to 0
	 * @param isGold true to add, false to subtract
	 */
	public void increment(double[] counts, UnaryRule rule, double[] weights, boolean isGold) {
		HierarchicalUnaryRule hr = (HierarchicalUnaryRule) rule;
		int offset = hr.identifier;
		int ruleLevel = hr.lastLevel;

		if (rule.parentState == 0) {
			for (int cp = 0; cp < nSubstates; cp++) {
				double val = weights[cp];
				if (val > 0) {
					int target = offset + lexiconMapping[ruleLevel][cp];
					if (isGold) {
						counts[target] += val;
					} else {
						counts[target] -= val;
					}
					weights[cp] = 0;
				}
			}
			return;
		}

		int curInd = 0;
		for (int cp = 0; cp < nSubstates; cp++) {
			for (int np = 0; np < nSubstates; np++, curInd++) {
				double val = weights[curInd];
				if (val > 0) {
					int target = offset + unaryMapping[ruleLevel][cp][np];
					if (isGold) {
						counts[target] += val;
					} else {
						counts[target] -= val;
					}
					weights[curInd] = 0;
				}
			}
		}
	}

	/**
	 * Adds (gold) or subtracts (non-gold) the strictly positive entries of
	 * {@code weights} to the count slots of a binary rule and zeroes each entry used.
	 * {@code weights} is read as an {@code nSubstates^3} row-major table mapped
	 * through {@code binaryMapping}.
	 *
	 * @param counts flat count vector that is updated in place
	 * @param rule rule to credit; cast to {@link HierarchicalBinaryRule}
	 * @param weights weights to fold in; consumed entries are set to 0
	 * @param isGold true to add, false to subtract
	 */
	public void increment(double[] counts, BinaryRule rule, double[] weights, boolean isGold) {
		HierarchicalBinaryRule hr = (HierarchicalBinaryRule) rule;
		int offset = hr.identifier;
		int ruleLevel = hr.lastLevel;

		int curInd = 0;
		for (int lp = 0; lp < nSubstates; lp++) {
			for (int rp = 0; rp < nSubstates; rp++) {
				for (int np = 0; np < nSubstates; np++, curInd++) {
					double val = weights[curInd];
					if (val > 0) {
						int target = offset + binaryMapping[ruleLevel][lp][rp][np];
						if (isGold) {
							counts[target] += val;
						} else {
							counts[target] -= val;
						}
						weights[curInd] = 0;
					}
				}
			}
		}
	}

	/** @return the hierarchical grammar held by this linearizer */
	public Grammar getGrammar() {
		return grammar;
	}

	/** @return the hierarchical lexicon held by this linearizer */
	public SimpleLexicon getLexicon() {
		return lexicon;
	}

	/** @return the span predictor held by this linearizer */
	public SpanPredictor getSpanPredictor() {
		return spanPredictor;
	}

}
